{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Intuitive Explanation of the Expectation Maximization Algorithm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Probabilistic models, such as hidden Markov models or Bayesian networks, are commonly used to model biological data. Much\n",
    "of their popularity can be attributed to the existence of efficient and robust procedures for learning parameters from observations. \n",
    "\n",
    "Often, however, the only data available for training a probabilistic model are incomplete. Missing values can occur, for example, in medical diagnosis, where patient histories generally include results from a limited battery of tests.\n",
    "\n",
    "Alternatively, in gene expression clustering, incomplete data arise from the intentional omission of gene-to-cluster assignments in the probabilistic model. \n",
    "\n",
    "**The expectation maximization algorithm enables parameter estimation in probabilistic models with incomplete data.**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### What is EM?\n",
    "\n",
    "- The Expectation-Maximization (EM) algorithm is a way to find **maximum-likelihood estimates** for model parameters when your data is incomplete, has missing data points, or has unobserved (hidden) latent variables.\n",
    "\n",
    "- It is an **iterative way to approximate the maximum-likelihood estimate.**\n",
    "\n",
    "- While maximum likelihood estimation can find the “best fit” model for a set of data, it is hard to apply directly when part of the data is missing or hidden.\n",
    "\n",
    "- The more complex EM algorithm can find model parameters even if you have missing data. It starts from a guess for the parameters and uses that guess to estimate the missing values. Those estimates are then used to compute better parameters, and the process continues until the algorithm converges on a fixed point."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How is EM different from MLE?\n",
    "\n",
    "Although Maximum Likelihood Estimation (MLE) and EM can both find “best-fit” parameters, the ways they find them are very different.\n",
    "\n",
    "- MLE accumulates all of the data first and then uses that data to construct the most likely model.\n",
    "- EM takes a guess at the parameters first — accounting for the missing data — then tweaks the model to fit the guesses and the observed data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How does EM work?\n",
    "\n",
    "The basic steps for the algorithm are:\n",
    "\n",
    "1. An initial guess is made for the model’s parameters.\n",
    "2. Using the current parameters, a probability distribution over the missing (hidden) values is computed from the observed data. This is called the “E-step”, for the “Expected” distribution.\n",
    "3. The parameters are re-estimated so that they best fit the observed data, weighted by the distribution from the E-step. This is called the “M-step”.\n",
    "4. Steps 2 and 3 are repeated until stability is reached (i.e. the parameters no longer change from one iteration to the next)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook uses very little math and aims to give an intuition for how EM converges."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Expectation Maximization using the two-coin example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- We have two biased coins, A and B, and we don't know the probability of getting heads on either coin. Call these probabilities $\\theta_A$ and $\\theta_B$.\n",
    "- We randomly choose one coin, toss it 10 times, and repeat this experiment five times.\n",
    "- If we know which coin was used in each experiment, a simple way to estimate $\\theta_A$ and $\\theta_B$ is to return the observed proportion of heads for each coin:\n",
    "\n",
    "  $$\\hat{\\theta}_A = \\frac{\\text{number of heads using coin A}}{\\text{total number of flips using coin A}}$$\n",
    "  \n",
    "  $$\\hat{\\theta}_B = \\frac{\\text{number of heads using coin B}}{\\text{total number of flips using coin B}}$$\n",
    "  \n",
    "  \n",
    "This intuitive guess is, in fact, known in the statistical literature as <font color ='red'>Maximum Likelihood Estimation</font>, **as shown in figure (a) below.**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"https://i.stack.imgur.com/mj0nb.gif\" width=\"500\" height=\"500\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**More Complicated Problem**:\n",
    "- Suppose that for each experiment we only know the number of heads, not which coin was used.\n",
    "- How can we estimate the probability of heads for each coin? This problem can be solved using <font color='red'>Expectation Maximization.</font>\n",
    "\n",
    "The steps are shown in **figure (b)** above:\n",
    "\n",
    "1. Assume some probability of heads for both coins (the values below correspond to $\\theta_A = 0.6$ and $\\theta_B = 0.5$).\n",
    "2. Compute, for each experiment, the probability that it came from each coin, using the binomial distribution (Expectation). For example, the second experiment has 9 heads and 1 tail. Its likelihood under each coin, leaving out the binomial coefficient that is common to both, is\n",
    "\n",
    "   $$p(A) = \\theta_A^9(1-\\theta_A)^{10-9} \\approx 0.004$$\n",
    "\n",
    "   $$p(B) = \\theta_B^9(1-\\theta_B)^{10-9} \\approx 0.001$$\n",
    "\n",
    "   Normalizing gives the probability that coin A or coin B was used:\n",
    "\n",
    "   $$\\frac{0.004}{0.004+0.001}=0.8$$\n",
    "\n",
    "   $$\\frac{0.001}{0.004+0.001}=0.2$$\n",
    "\n",
    "   Similarly, we can calculate the distribution over coin A and coin B for every experiment.\n",
    "\n",
    "3. Based on these distributions, calculate the new estimated probabilities (Maximization).\n",
    "4. Repeat steps 2 and 3.\n",
    "5. Report the estimated probabilities after convergence.\n",
    "\n",
    "It can be shown that EM never decreases the likelihood from one iteration to the next, so it converges, although possibly to a local rather than the global maximum [ref](https://www.cmi.ac.in/~madhavan/courses/dmml2018/literature/EM_algorithm_2coin_example.pdf)\n"
   ]
  },
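  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not a reference implementation: the head counts (5, 9, 8, 4 and 7 out of 10 tosses) and the starting values 0.6 and 0.5 are read off figure (b), both coins are assumed equally likely to be picked, and the function name `em_two_coins` is made up for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# number of heads in each of the five experiments of figure (b)\n",
    "heads = np.array([5, 9, 8, 4, 7])\n",
    "n_tosses = 10\n",
    "\n",
    "\n",
    "def em_two_coins(heads, n_tosses, theta_a, theta_b, n_iter=10):\n",
    "    # hypothetical helper: EM for two coins that are equally likely to be picked\n",
    "    tails = n_tosses - heads\n",
    "    for _ in range(n_iter):\n",
    "        # E-step: likelihood of each experiment under each coin\n",
    "        # (the binomial coefficient cancels in the ratio, so it is left out)\n",
    "        like_a = theta_a**heads * (1 - theta_a) ** tails\n",
    "        like_b = theta_b**heads * (1 - theta_b) ** tails\n",
    "        weight_a = like_a / (like_a + like_b)\n",
    "        weight_b = 1 - weight_a\n",
    "\n",
    "        # M-step: weighted proportion of heads attributed to each coin\n",
    "        theta_a = np.sum(weight_a * heads) / np.sum(weight_a * n_tosses)\n",
    "        theta_b = np.sum(weight_b * heads) / np.sum(weight_b * n_tosses)\n",
    "    return theta_a, theta_b\n",
    "\n",
    "\n",
    "theta_a, theta_b = em_two_coins(heads, n_tosses, theta_a=0.6, theta_b=0.5)\n",
    "print(\"theta_A = {:.2f}, theta_B = {:.2f}\".format(theta_a, theta_b))"
   ]
  },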
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Expectation maximization for normal distributions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this experiment we will try to find the mean and standard deviation (std) of two normal distributions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Let's generate some data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "np.random.seed(110)  # for reproducible random results\n",
    "\n",
    "# set parameters\n",
    "red_mean = 3\n",
    "red_std = 0.8\n",
    "\n",
    "blue_mean = 7\n",
    "blue_std = 2\n",
    "\n",
    "# draw 20 samples from normal distributions with red/blue parameters\n",
    "red = np.random.normal(red_mean, red_std, size=20)\n",
    "blue = np.random.normal(blue_mean, blue_std, size=20)\n",
    "\n",
    "both_colours = np.sort(np.concatenate((red, blue)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.mlab as mlab\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.rcParams[\"figure.figsize\"] = (20, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = [1 for i in range(len(red))]\n",
    "plt.plot(red, y, \"ro\", blue, y, \"bo\")\n",
    "plt.title(\"Data(With Label)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**If the labels are given, the problem is easy: we can calculate the mean and std of each colour using MLE.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Red Mean is {}\".format(np.mean(red)))\n",
    "print(\"Red Std is {}\".format(np.std(red)))\n",
    "print(\"\")\n",
    "print(\"Blue Mean is {}\".format(np.mean(blue)))\n",
    "print(\"Blue Std is {}\".format(np.std(blue)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**If the labels are not given, use EM.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = [1 for i in range(len(both_colours))]\n",
    "plt.plot(both_colours, y, \"mo\")\n",
    "plt.title(\"Data(Without Label)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 (Guess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# estimates for the mean\n",
    "red_mean_guess = 1.1\n",
    "blue_mean_guess = 9\n",
    "\n",
    "# estimates for the standard deviation\n",
    "red_std_guess = 2\n",
    "blue_std_guess = 1.7"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The plot will look like this"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.linspace(-3, 12, 100)\n",
    "# density of each guessed normal distribution on the grid x\n",
    "red_pdf = stats.norm.pdf(x, red_mean_guess, red_std_guess)\n",
    "blue_pdf = stats.norm.pdf(x, blue_mean_guess, blue_std_guess)\n",
    "plt.plot(x, red_pdf, \"r\", x, blue_pdf, \"b\", alpha=1)\n",
    "plt.ylim(-0.02, 0.28)\n",
    "y = [0.001 for i in range(len(both_colours))]\n",
    "plt.plot(both_colours, y, \"mo\")\n",
    "# axvline takes ymin/ymax as fractions of the axis height\n",
    "ymaxr = max(red_pdf) / 0.28\n",
    "ymaxb = max(blue_pdf) / 0.28\n",
    "plt.axvline(red_mean_guess, color=\"r\", linestyle=\"--\", ymin=0.067, ymax=ymaxr)\n",
    "plt.axvline(blue_mean_guess, color=\"b\", linestyle=\"--\", ymin=0.067, ymax=ymaxb);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2\n",
    "### Likelihood of each point under current guess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)\n",
    "likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\n",
    "    \"likelihood of point {} being red {}\".format(both_colours[1], likelihood_of_red[1])\n",
    ")\n",
    "print(\n",
    "    \"likelihood of point {} being blue {}\".format(\n",
    "        both_colours[1], likelihood_of_blue[1]\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3\n",
    "### Convert Likelihood into Weights"
   ]
  },
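  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# both colours are assumed equally likely a priori, so the weights are\n",
    "# simply the likelihoods normalised to sum to 1 for each data point\n",
    "likelihood_total = likelihood_of_red + likelihood_of_blue\n",
    "\n",
    "red_weight = likelihood_of_red / likelihood_total\n",
    "blue_weight = likelihood_of_blue / likelihood_total"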
   ]
  },
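  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In symbols, the weights computed above can be written as follows. This is only a sketch: $x_i$ (a data point), $w_i$ (its weight) and $\\mathcal{N}(x \\mid \\mu, \\sigma)$ (the normal density) are notation assumed here, and, as in the code, both colours are taken to be equally likely a priori, so no mixing proportions appear.\n",
    "\n",
    "$$w_i^{red} = \\frac{\\mathcal{N}(x_i \\mid \\mu_{red}, \\sigma_{red})}{\\mathcal{N}(x_i \\mid \\mu_{red}, \\sigma_{red}) + \\mathcal{N}(x_i \\mid \\mu_{blue}, \\sigma_{blue})}, \\qquad w_i^{blue} = 1 - w_i^{red}$$\n",
    "\n",
    "A point that is far more likely under the red guess than under the blue one gets a red weight close to 1, and a point that is about equally likely under both gets weights close to 0.5."
   ]
  },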
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4\n",
    "\n",
    "### With current estimates and new weights we can compute new estimates"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The key bit of intuition is that the greater the weight of a colour on a data point, the more the data point influences the next estimates for that colour's parameters. This has the effect of \"pulling\" the parameters in the right direction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def estimate_mean(data, weight):\n",
    "    # weighted average of the data\n",
    "    return np.sum(data * weight) / np.sum(weight)\n",
    "\n",
    "\n",
    "def estimate_std(data, weight, mean):\n",
    "    # weighted standard deviation of the data around the given mean\n",
    "    variance = np.sum(weight * (data - mean) ** 2) / np.sum(weight)\n",
    "    return np.sqrt(variance)\n",
    "\n",
    "\n",
    "# new estimates for mean\n",
    "new_red_mean_guess = estimate_mean(both_colours, red_weight)\n",
    "new_blue_mean_guess = estimate_mean(both_colours, blue_weight)\n",
    "\n",
    "# new estimates for standard deviation, measured around the new means\n",
    "new_red_std_guess = estimate_std(both_colours, red_weight, new_red_mean_guess)\n",
    "new_blue_std_guess = estimate_std(both_colours, blue_weight, new_blue_mean_guess)"
   ]
  },
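  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two helper functions above compute weighted versions of the usual mean and standard deviation. As a sketch, with $x_i$ the data points and $w_i$ the weights of one colour (notation assumed here for illustration), and $\\mu$ the mean passed to `estimate_std`:\n",
    "\n",
    "$$\\mu^{new} = \\frac{\\sum_i w_i x_i}{\\sum_i w_i}, \\qquad \\sigma^{new} = \\sqrt{\\frac{\\sum_i w_i (x_i - \\mu)^2}{\\sum_i w_i}}$$\n",
    "\n",
    "If every weight were 1 and $\\mu$ were the sample mean, these would reduce to the ordinary mean and (population) standard deviation, which is what `np.mean` and `np.std` return for the labelled data."
   ]
  },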
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\n",
    "    \"Red Mean New: {} and Blue Mean New: {}\".format(\n",
    "        new_red_mean_guess, new_blue_mean_guess\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5 (Repeat)\n",
    "### Let's run for 5 iterations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# estimates for the mean\n",
    "red_mean_guess = 1.1\n",
    "blue_mean_guess = 9\n",
    "\n",
    "# estimates for the standard deviation\n",
    "red_std_guess = 2\n",
    "blue_std_guess = 1.7\n",
    "\n",
    "n_iterations = 5\n",
    "x = np.linspace(-3, 12, 100)\n",
    "y = [0.001 for i in range(len(both_colours))]\n",
    "plt.plot(both_colours, y, \"mo\")\n",
    "plt.ylim(-0.02, 0.40)\n",
    "\n",
    "for i in range(n_iterations):\n",
    "    # E-step: weight of each colour on every data point\n",
    "    likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)\n",
    "    likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)\n",
    "\n",
    "    likelihood_total = likelihood_of_red + likelihood_of_blue\n",
    "\n",
    "    red_weight = likelihood_of_red / likelihood_total\n",
    "    blue_weight = likelihood_of_blue / likelihood_total\n",
    "\n",
    "    # M-step: update the means, then the standard deviations around them\n",
    "    red_mean_guess = estimate_mean(both_colours, red_weight)\n",
    "    blue_mean_guess = estimate_mean(both_colours, blue_weight)\n",
    "    red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)\n",
    "    blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)\n",
    "\n",
    "    print(red_mean_guess, blue_mean_guess)\n",
    "\n",
    "    red_pdf = stats.norm.pdf(x, red_mean_guess, red_std_guess)\n",
    "    blue_pdf = stats.norm.pdf(x, blue_mean_guess, blue_std_guess)\n",
    "    ymaxr = max(red_pdf) / 0.39\n",
    "    ymaxb = max(blue_pdf) / 0.39\n",
    "\n",
    "    # later iterations are drawn more opaque; the last one is fully opaque\n",
    "    alpha = (i + 1) / n_iterations\n",
    "    plt.plot(x, red_pdf, \"r\", x, blue_pdf, \"b\", alpha=alpha)\n",
    "    plt.axvline(\n",
    "        red_mean_guess, color=\"r\", linestyle=\"--\", ymin=0.067, ymax=ymaxr, alpha=alpha\n",
    "    )\n",
    "    plt.axvline(\n",
    "        blue_mean_guess, color=\"b\", linestyle=\"--\", ymin=0.067, ymax=ymaxb, alpha=alpha\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the estimated means move toward the true means with each iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# estimates for the mean\n",
    "red_mean_guess = 1.1\n",
    "blue_mean_guess = 9\n",
    "\n",
    "# estimates for the standard deviation\n",
    "red_std_guess = 2\n",
    "blue_std_guess = 1.7\n",
    "\n",
    "for i in range(20):\n",
    "    # E-step: weight of each colour on every data point\n",
    "    likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)\n",
    "    likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)\n",
    "\n",
    "    likelihood_total = likelihood_of_red + likelihood_of_blue\n",
    "\n",
    "    red_weight = likelihood_of_red / likelihood_total\n",
    "    blue_weight = likelihood_of_blue / likelihood_total\n",
    "\n",
    "    # M-step: update the means, then the standard deviations around them\n",
    "    red_mean_guess = estimate_mean(both_colours, red_weight)\n",
    "    blue_mean_guess = estimate_mean(both_colours, blue_weight)\n",
    "    red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)\n",
    "    blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)\n",
    "\n",
    "y = [0.001 for i in range(len(both_colours))]\n",
    "plt.plot(both_colours, y, \"mo\")\n",
    "plt.ylim(-0.02, 0.50)\n",
    "\n",
    "# density of each fitted normal distribution on the grid x\n",
    "x = np.linspace(-3, 12, 100)\n",
    "red_pdf = stats.norm.pdf(x, red_mean_guess, red_std_guess)\n",
    "blue_pdf = stats.norm.pdf(x, blue_mean_guess, blue_std_guess)\n",
    "ymaxr = max(red_pdf) / 0.5\n",
    "ymaxb = max(blue_pdf) / 0.5\n",
    "\n",
    "plt.plot(x, red_pdf, \"r\", x, blue_pdf, \"b\", alpha=1)\n",
    "plt.axvline(red_mean_guess, color=\"r\", linestyle=\"--\", ymin=0.06, ymax=ymaxr, alpha=1)\n",
    "plt.axvline(blue_mean_guess, color=\"b\", linestyle=\"--\", ymin=0.06, ymax=ymaxb, alpha=1)"
   ]
  },
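  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loops above run for a fixed number of iterations. A common alternative, sketched here, is to stop once the log-likelihood of the data stops improving. This is an illustrative variation rather than part of the walkthrough: the function name `em_two_normals`, the tolerance `tol` and the cap `max_iter` are assumptions made for this example, the cell draws its own synthetic sample with the same distribution parameters, and both components are assumed equally likely a priori."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "# synthetic sample drawn with the same parameters as the red and blue data\n",
    "rng = np.random.RandomState(110)\n",
    "sample = np.concatenate((rng.normal(3, 0.8, size=20), rng.normal(7, 2, size=20)))\n",
    "\n",
    "\n",
    "def em_two_normals(data, means, stds, tol=1e-6, max_iter=200):\n",
    "    # hypothetical helper: EM for two equally weighted normal components,\n",
    "    # stopping once the log-likelihood improves by less than tol\n",
    "    means = np.array(means, dtype=float)\n",
    "    stds = np.array(stds, dtype=float)\n",
    "    log_likelihoods = []\n",
    "    for _ in range(max_iter):\n",
    "        # likelihood of every point under each component, shape (2, n_points)\n",
    "        like = np.array([stats.norm.pdf(data, m, s) for m, s in zip(means, stds)])\n",
    "        log_likelihoods.append(np.sum(np.log(0.5 * like.sum(axis=0))))\n",
    "        if len(log_likelihoods) > 1 and log_likelihoods[-1] - log_likelihoods[-2] < tol:\n",
    "            break\n",
    "        # E-step: normalise the likelihoods into weights\n",
    "        weights = like / like.sum(axis=0)\n",
    "        # M-step: weighted means, then weighted standard deviations around them\n",
    "        means = (weights * data).sum(axis=1) / weights.sum(axis=1)\n",
    "        variances = (weights * (data - means[:, None]) ** 2).sum(axis=1) / weights.sum(axis=1)\n",
    "        stds = np.sqrt(variances)\n",
    "    return means, stds, log_likelihoods\n",
    "\n",
    "\n",
    "fitted_means, fitted_stds, log_likelihoods = em_two_normals(\n",
    "    sample, means=[1.1, 9], stds=[2, 1.7]\n",
    ")\n",
    "print(\"stopped after {} log-likelihood evaluations\".format(len(log_likelihoods)))\n",
    "print(\"means:\", fitted_means, \"stds:\", fitted_stds)"
   ]
  },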
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ! pip install prettytable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from prettytable import PrettyTable\n",
    "\n",
    "t = PrettyTable([\"\", \"EM Guess\", \"Actual\", \"Delta\"])\n",
    "t.add_row([\"Red mean\", red_mean_guess, np.mean(red), red_mean_guess - np.mean(red)])\n",
    "t.add_row([\"Red std\", red_std_guess, np.std(red), red_std_guess - np.std(red)])\n",
    "t.add_row(\n",
    "    [\"Blue mean\", blue_mean_guess, np.mean(blue), blue_mean_guess - np.mean(blue)]\n",
    ")\n",
    "t.add_row([\"Blue std\", blue_std_guess, np.std(blue), blue_std_guess - np.std(blue)])\n",
    "print(t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### Conclusion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The expectation maximization algorithm belongs in every data scientist's toolbox. This tutorial has tried to explain intuitively how EM works."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Intuitive Explanation of EM](https://stackoverflow.com/questions/11808074/what-is-an-intuitive-explanation-of-the-expectation-maximization-technique)\n",
    "\n",
    "[Expectation Maximization Primer](https://www.cmi.ac.in/~madhavan/courses/dmml2018/literature/EM_algorithm_2coin_example.pdf)\n",
    "\n",
    "[How EM works](https://math.stackexchange.com/questions/25111/how-does-expectation-maximization-work)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
